{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# type: ignore"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Unsloth Supervised Fine-Tuning\n",
    "\n",
    "This recipe allows TensorZero users to fine-tune models using [Unsloth](https://unsloth.ai) and their own data.\n",
    "Since TensorZero automatically logs all inferences and feedback, it is straightforward to fine-tune a model using your own data and any prompt you want.\n",
    "\n",
    "We demonstrate how to deploy a LoRA fine-tuned model for serverless inference using [Fireworks](https://fireworks.ai). Full instructions to deploy LoRA or full fine-tuned models are provided by [Fireworks](https://docs.fireworks.ai/fine-tuning/fine-tuning-models), [Together](https://docs.together.ai/docs/deploying-a-fine-tuned-model), and other inference providers. You can also use [vLLM](https://docs.vllm.ai/en/latest/examples/online_serving/api_client.html) to serve your fine-tuned model locally. The TensorZero client seemlessly integrates inference using your fine-tuned model for any of these approaches.\n",
    "\n",
    "To get started:\n",
    "\n",
    "- Set your `TENSORZERO_CLICKHOUSE_URL` enironment variable to point to the database containing the historical inferences you'd like to train on.\n",
    "- You'll also need to [install](https://docs.fireworks.ai/tools-sdks/firectl/firectl) the CLI tool `firectl` on your machine and sign in with `firectl signin`. You can test that this all worked with `firectl whoami`.\n",
    "- Update the following parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG_PATH = \"../../../examples/data-extraction-ner/config/tensorzero.toml\"\n",
    "\n",
    "FUNCTION_NAME = \"extract_entities\"\n",
    "\n",
    "METRIC_NAME = \"jaccard_similarity\"\n",
    "\n",
    "# The name of the variant to use to grab the templates used for fine-tuning\n",
    "TEMPLATE_VARIANT_NAME = \"gpt_4o_mini\"  # It's OK that this variant uses a different model than the one we're fine-tuning\n",
    "\n",
    "# If the metric is a float metric, you can set the threshold to filter the data\n",
    "FLOAT_METRIC_THRESHOLD = 0.5\n",
    "\n",
    "# Fraction of the data to use for validation\n",
    "VAL_FRACTION = 0.2\n",
    "\n",
    "# Maximum number of samples to use for fine-tuning\n",
    "MAX_SAMPLES = 100_000\n",
    "\n",
    "# Random seed\n",
    "SEED = 42"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Select a model to fine tune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The name of the model to fine-tune (supported models: https://docs.unsloth.ai/get-started/all-our-models)\n",
    "MODEL_NAME = \"unsloth/Meta-Llama-3.1-8B-Instruct\"\n",
    "\n",
    "SERVERLESS = True  # Whether to use a serverless deployment. Set to False is full model fine tuning or using LoRA for a model without serverless support\n",
    "\n",
    "MAX_SEQ_LENGTH = 8192  # Choose any! Unsloth supports RoPE Scaling internally!\n",
    "\n",
    "MODEL_DTYPE = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n",
    "\n",
    "LOAD_IN_4BIT = True  # Use 4bit quantization to reduce memory usage. Can be False."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose the appropriate chat template for the selected model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from unsloth.chat_templates import CHAT_TEMPLATES\n",
    "\n",
    "print(list(CHAT_TEMPLATES.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the chat template corresponding the the model you're fine-tuning.\n",
    "# For example, if you're fine-tuning \"unsloth/Meta-Llama-3.1-8B-Instruct\" you should use \"llama-3.1\"\n",
    "CHAT_TEMPLATE = \"llama-3.1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 1\n",
    "\n",
    "LEARNING_RATE = 2e-4\n",
    "\n",
    "BATCH_SIZE = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, use Low Rank Adaptation.\n",
    "\n",
    "Some [Fireworks Models]() support [serverless LoRA deployment](https://docs.fireworks.ai/fine-tuning/fine-tuning-models), but full fine-tuning usually needs some form of reserved capacity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Whether to use LoRA or not. Set to False for full model fine-tuning\n",
    "# If set to False, SEVERLESS must also be False as you will need to create your own deployment\n",
    "USE_LORA = True\n",
    "\n",
    "# LoRA Parameters\n",
    "LORA_R = 8  # LoRA rank (the bottleneck dimension in the adaptation matrices)\n",
    "LORA_ALPHA = 16  # LoRA scaling factor (sometimes set to 2x the rank)\n",
    "LORA_DROPOUT = 0.0  # Dropout rate applied to the LoRA layers (sometimes 0.05 or 0.1)\n",
    "LORA_TARGETS = [  # Which modules to inject LoRA into (often q_proj, v_proj, or all linear layers in attention)\n",
    "    \"q_proj\",\n",
    "    \"k_proj\",\n",
    "    \"v_proj\",\n",
    "    \"o_proj\",\n",
    "    \"gate_proj\",\n",
    "    \"up_proj\",\n",
    "    \"down_proj\",\n",
    "]\n",
    "LORA_BIAS = \"none\"  # Whether to add bias in LoRA adapters (rarely needed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "tensorzero_path = os.path.abspath(os.path.join(os.getcwd(), \"../../../\"))\n",
    "if tensorzero_path not in sys.path:\n",
    "    sys.path.append(tensorzero_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import tempfile\n",
    "from typing import Any, Dict\n",
    "\n",
    "import toml\n",
    "from datasets import Dataset\n",
    "from tensorzero import (\n",
    "    FloatMetricFilter,\n",
    "    TensorZeroGateway,\n",
    ")\n",
    "from tensorzero.util import uuid7\n",
    "from transformers import TrainingArguments\n",
    "from trl import SFTTrainer\n",
    "from unsloth import FastLanguageModel, is_bfloat16_supported\n",
    "from unsloth.chat_templates import get_chat_template\n",
    "\n",
    "from recipes.util import tensorzero_rendered_samples_to_conversations, train_val_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load and render the stored inferences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorzero_client = TensorZeroGateway.build_embedded(\n",
    "    config_file=CONFIG_PATH,\n",
    "    clickhouse_url=os.environ[\"TENSORZERO_CLICKHOUSE_URL\"],\n",
    "    timeout=15,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the metric filter as needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "comparison_operator = \">=\"\n",
    "metric_node = FloatMetricFilter(\n",
    "    metric_name=METRIC_NAME,\n",
    "    value=FLOAT_METRIC_THRESHOLD,\n",
    "    comparison_operator=comparison_operator,\n",
    ")\n",
    "# from tensorzero import BooleanMetricFilter\n",
    "# metric_node = BooleanMetricFilter(\n",
    "#     metric_name=METRIC_NAME,\n",
    "#     value=True  # or False\n",
    "# )\n",
    "\n",
    "metric_node"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Query the inferences and feedback from ClickHouse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stored_inferences = tensorzero_client.experimental_list_inferences(\n",
    "    function_name=FUNCTION_NAME,\n",
    "    variant_name=None,\n",
    "    output_source=\"inference\",  # could also be \"demonstration\"\n",
    "    filters=metric_node,\n",
    "    limit=MAX_SAMPLES,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Render the stored inferences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rendered_samples = tensorzero_client.experimental_render_samples(\n",
    "    stored_inferences=stored_inferences,\n",
    "    variants={FUNCTION_NAME: TEMPLATE_VARIANT_NAME},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the data into training and validation sets for fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_samples, eval_samples = train_val_split(\n",
    "    rendered_samples,\n",
    "    val_size=VAL_FRACTION,\n",
    "    last_inference_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert the rendered samples to conversations for tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_conversations = tensorzero_rendered_samples_to_conversations(train_samples)\n",
    "eval_conversations = tensorzero_rendered_samples_to_conversations(eval_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate the model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=MODEL_NAME,\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
    "    dtype=MODEL_DTYPE,\n",
    "    load_in_4bit=LOAD_IN_4BIT,\n",
    "    # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Apply the chat completion template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = get_chat_template(\n",
    "    tokenizer,\n",
    "    chat_template=CHAT_TEMPLATE,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_conversations(inference: Dict[str, Any]):\n",
    "    inference.update({\"add_generation_prompt\": False, \"tokenize\": False})\n",
    "    return {\n",
    "        \"text\": tokenizer.apply_chat_template(\n",
    "            **inference,\n",
    "        )\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = Dataset.from_list([process_conversations(sample) for sample in train_conversations])\n",
    "eval_dataset = Dataset.from_list([process_conversations(sample) for sample in eval_conversations])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set LoRA parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if USE_LORA:\n",
    "    model = FastLanguageModel.get_peft_model(\n",
    "        model,\n",
    "        r=LORA_R,\n",
    "        lora_alpha=LORA_ALPHA,\n",
    "        lora_dropout=LORA_DROPOUT,\n",
    "        target_modules=LORA_TARGETS,\n",
    "        bias=LORA_BIAS,\n",
    "        # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n",
    "        use_gradient_checkpointing=\"unsloth\",  # True or \"unsloth\" for very long context\n",
    "        random_state=SEED,\n",
    "        use_rslora=False,  # Unsloth supports rank stabilized LoRA\n",
    "        loftq_config=None,  # And LoftQ\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
    "    dataset_num_proc=2,\n",
    "    packing=False,  # Can make training 5x faster for short sequences.\n",
    "    args=TrainingArguments(\n",
    "        eval_strategy=\"steps\",\n",
    "        eval_steps=20,\n",
    "        per_device_train_batch_size=BATCH_SIZE,\n",
    "        per_device_eval_batch_size=BATCH_SIZE,\n",
    "        gradient_accumulation_steps=1,\n",
    "        learning_rate=LEARNING_RATE,\n",
    "        weight_decay=0.01,\n",
    "        num_train_epochs=NUM_EPOCHS,  # Set this for 1 full training run.\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        warmup_steps=5,\n",
    "        logging_steps=10,\n",
    "        save_strategy=\"no\",\n",
    "        seed=SEED,\n",
    "        bf16=is_bfloat16_supported(),\n",
    "        fp16=not is_bfloat16_supported(),\n",
    "        optim=\"adamw_8bit\",\n",
    "        report_to=\"none\",  # Use this for WandB etc\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"UNSLOTH_RETURN_LOGITS\"] = \"1\"\n",
    "trainer_stats = trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the model is done training, we need to [deploy](https://docs.fireworks.ai/fine-tuning/fine-tuning-models#deploying-and-using-a-model) it to Fireworks serverless inference. If you need high or guaranteed throughput you can also deploy the model to [reserved capacity](https://docs.fireworks.ai/deployments/reservations) or an on-demand [deployment](https://docs.fireworks.ai/guides/ondemand-deployments)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model_id = \"llama-v3p1-8b-instruct\"\n",
    "fine_tuned_model_id = f\"{MODEL_NAME.lower().replace('/', '-').replace('.', 'p')}-{str(uuid7()).split('-')[-1]}\"\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmpdirname:\n",
    "    tmpdirname = \"trainer_output\"\n",
    "    print(f\"Saving to temp dir: {tmpdirname}\")\n",
    "    model.save_pretrained(tmpdirname)\n",
    "    tokenizer.save_pretrained(tmpdirname)\n",
    "\n",
    "    base_model_path = f\"accounts/fireworks/models/{base_model_id}\"\n",
    "    command = [\n",
    "        \"firectl\",\n",
    "        \"create\",\n",
    "        \"model\",\n",
    "        fine_tuned_model_id,\n",
    "        tmpdirname,\n",
    "        \"--base-model\",\n",
    "        base_model_path,\n",
    "    ]\n",
    "    try:\n",
    "        result = subprocess.run(command, capture_output=True)\n",
    "        stdout = result.stdout.decode(\"utf-8\")\n",
    "        print(\"Command output:\", stdout)\n",
    "    except subprocess.CalledProcessError as e:\n",
    "        print(\"Error occurred:\", e.stderr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_id(stdout: str) -> str:\n",
    "    for line in stdout.splitlines():\n",
    "        if line.strip().startswith(\"Name:\"):\n",
    "            return line.split(\":\")[1].strip()\n",
    "    raise ValueError(\"Model ID not found in output\")\n",
    "\n",
    "\n",
    "model_identifier = get_model_id(stdout)\n",
    "\n",
    "model_identifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a deployment if not using a model with serverless support, if it does not support serveless addons, or if you are doing full fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not SERVERLESS:\n",
    "    command = [\"firectl\", \"create\", \"deployment\", model_identifier]\n",
    "    print(\" \".join(command))\n",
    "    result = subprocess.run(command, capture_output=True)\n",
    "    if result.returncode != 0:\n",
    "        print(result.stderr.decode(\"utf-8\"))\n",
    "    else:\n",
    "        stdout = result.stdout.decode(\"utf-8\")\n",
    "        print(stdout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the LoRA addon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if USE_LORA:\n",
    "    command = [\"firectl\", \"load-lora\", model_identifier]\n",
    "    print(\" \".join(command))\n",
    "    result = subprocess.run(command, capture_output=True)\n",
    "    if result.returncode != 0:\n",
    "        print(result.stderr.decode(\"utf-8\"))\n",
    "    else:\n",
    "        stdout = result.stdout.decode(\"utf-8\")\n",
    "        print(stdout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the model is deployed, you can add the fine-tuned model and a new variant to your config file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = {\n",
    "    \"models\": {\n",
    "        model_identifier: {\n",
    "            \"routing\": [\"fireworks\"],\n",
    "            \"providers\": {\"fireworks\": {\"type\": \"fireworks\", \"model_name\": model_identifier}},\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "print(toml.dumps(model_config))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You're all set!\n",
    "\n",
    "You can change the weight to enable a gradual rollout of the new model."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,py:percent",
   "main_language": "python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
